# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

from collections import OrderedDict
from pathlib import Path

import numpy as np

from ..._fiff.meas_info import create_info
from ...evoked import EvokedArray
from ...utils import fill_doc, logger, verbose


@fill_doc
@verbose
def read_evoked_besa(fname, verbose=None):
    """Load evoked data from a BESA ``.avr`` or ``.mul`` file.

    If a ``.elp`` sidecar file sits next to the data file, electrode
    information is taken from it.

    Parameters
    ----------
    fname : path-like
        Location of the ``.avr`` or ``.mul`` file to read.
    %(verbose)s

    Returns
    -------
    ev : Evoked
        The evoked data stored in the file.
    """
    fname = Path(fname)
    # Select the format-specific reader from the file extension (the match is
    # case-sensitive); any other extension is rejected with a ValueError.
    readers = {".avr": _read_evoked_besa_avr, ".mul": _read_evoked_besa_mul}
    reader = readers.get(fname.suffix)
    if reader is None:
        raise ValueError("Filename must end in either .avr or .mul")
    return reader(fname, verbose)


@verbose
def _read_evoked_besa_avr(fname, verbose):
    """Create EvokedArray from a BESA .avr file.

    The first line of the file is a header made of ``name= value`` fields.
    Files whose header contains ``Nchan=`` ("new style") list the channel names
    on the second line. Older files do not, so channel names are taken from the
    ``.elp`` sidecar file or, when there is none, generated as ``CH01``,
    ``CH02``, ...

    Parameters
    ----------
    fname : Path
        Path to the ``.avr`` file.
    verbose : bool | str | int | None
        Verbosity setting, passed on to ``EvokedArray``.

    Returns
    -------
    ev : EvokedArray
        The evoked data, with one channel per row of the data matrix in the
        file.

    Raises
    ------
    RuntimeError
        If the number of channel names (from the file or from the ``.elp``
        sidecar) differs from the number of rows in the data matrix, or if the
        header has no ``DI`` field.
    """
    with open(fname) as f:
        header = f.readline().strip()

        # There are two versions of .avr files. The old style, generated by
        # BESA 1, 2 and 3 does not define Nchan and does not have channel names
        # in the file.
        new_style = "Nchan=" in header
        ch_names = f.readline().strip().split() if new_style else None

    fields = _parse_header(header)
    # Skip the header line, plus the channel name line for new style files.
    data = np.loadtxt(fname, skiprows=2 if new_style else 1, ndmin=2)
    ch_types = _read_elp_sidecar(fname)

    # Consolidate channel names
    if new_style:
        if len(ch_names) != len(data):
            raise RuntimeError(
                "Mismatch between the number of channel names defined in "
                f"the .avr file ({len(ch_names)}) and the number of rows "
                f"in the data matrix ({len(data)})."
            )
    elif ch_types is not None:
        # Determine channel names from the .elp sidecar file
        ch_names = list(ch_types.keys())
        if len(ch_names) != len(data):
            raise RuntimeError(
                "Mismatch between the number of channels "
                f"defined in the .avr file ({len(data)}) "
                f"and .elp file ({len(ch_names)})."
            )
    else:
        logger.info(
            "No .elp file found and no channel names present in "
            "the .avr file. Falling back to generic names. "
        )
        ch_names = [f"CH{i + 1:02d}" for i in range(len(data))]

    # Consolidate channel types: one entry per channel, in ch_names order.
    if ch_types is None:
        logger.info("Marking all channels as EEG.")
        ch_types = ["eeg"] * len(ch_names)
    else:
        ch_types = [ch_types[ch] for ch in ch_names]

    # Go over all the header fields and make sure they are all defined to
    # something sensible. Size mismatches are only logged, the data matrix
    # itself is what ends up in the Evoked object.
    if "Npts" in fields:
        fields["Npts"] = int(fields["Npts"])
        if fields["Npts"] != data.shape[1]:
            logger.warning(
                f"The size of the data matrix ({data.shape}) does not "
                f'match the "Npts" field ({fields["Npts"]}).'
            )
    if "Nchan" in fields:
        fields["Nchan"] = int(fields["Nchan"])
        if fields["Nchan"] != data.shape[0]:
            logger.warning(
                f"The size of the data matrix ({data.shape}) does not "
                f'match the "Nchan" field ({fields["Nchan"]}).'
            )
    if "DI" in fields:
        fields["DI"] = float(fields["DI"])
    else:
        raise RuntimeError(
            'No "DI" field present. Could not determine sampling frequency.'
        )
    if "TSB" in fields:
        fields["TSB"] = float(fields["TSB"])
    else:
        fields["TSB"] = 0
    if "SB" in fields:
        fields["SB"] = float(fields["SB"])
    else:
        fields["SB"] = 1.0
    if "SegmentName" not in fields:
        fields["SegmentName"] = ""

    # Build the Evoked object based on the header fields. The consolidated
    # per-channel types are passed on so that the information from the .elp
    # sidecar is honoured, as in the .mul reader.
    # NOTE(review): the arithmetic treats DI as the sampling interval in ms,
    # TSB as the start time in ms and SB as a scale factor applied before the
    # 1e6 division (presumably bins per microvolt) -- confirm against the BESA
    # format description.
    info = create_info(ch_names, sfreq=1000 / fields["DI"], ch_types=ch_types)
    return EvokedArray(
        data / fields["SB"] / 1e6,
        info,
        tmin=fields["TSB"] / 1000,
        comment=fields["SegmentName"],
        verbose=verbose,
    )


@verbose
def _read_evoked_besa_mul(fname, verbose):
    """Create EvokedArray from a BESA .mul file.

    The first line of the file is a header made of ``name= value`` fields, the
    second line lists the channel names and the remaining lines hold the data
    matrix with one time point per row and one channel per column.

    Parameters
    ----------
    fname : Path
        Path to the ``.mul`` file.
    verbose : bool | str | int | None
        Verbosity setting, passed on to ``EvokedArray``.

    Returns
    -------
    ev : EvokedArray
        The evoked data. The data matrix of the file is transposed so that
        channels end up along the first axis.

    Raises
    ------
    RuntimeError
        If the number of channel names differs from the number of columns in
        the data matrix, or if the header has no ``SamplingInterval[ms]``
        field.
    """
    with open(fname) as f:
        header = f.readline().strip()
        ch_names = f.readline().strip().split()

    fields = _parse_header(header)
    # Skip the header line and the channel name line.
    data = np.loadtxt(fname, skiprows=2, ndmin=2)

    if len(ch_names) != data.shape[1]:
        raise RuntimeError(
            "Mismatch between the number of channel names "
            f"defined in the .mul file ({len(ch_names)}) "
            "and the number of columns in the data matrix "
            f"({data.shape[1]})."
        )

    # Consolidate channel types: one entry per channel, in ch_names order.
    ch_types = _read_elp_sidecar(fname)
    if ch_types is None:
        logger.info("Marking all channels as EEG.")
        ch_types = ["eeg"] * len(ch_names)
    else:
        ch_types = [ch_types[ch] for ch in ch_names]

    # Go over all the header fields and make sure they are all defined to
    # something sensible. Size mismatches are only logged, the data matrix
    # itself is what ends up in the Evoked object.
    if "TimePoints" in fields:
        fields["TimePoints"] = int(fields["TimePoints"])
        if fields["TimePoints"] != data.shape[0]:
            logger.warning(
                f"The size of the data matrix ({data.shape}) does not "
                f'match the "TimePoints" field ({fields["TimePoints"]}).'
            )
    if "Channels" in fields:
        fields["Channels"] = int(fields["Channels"])
        if fields["Channels"] != data.shape[1]:
            logger.warning(
                f"The size of the data matrix ({data.shape}) does not "
                f'match the "Channels" field ({fields["Channels"]}).'
            )
    if "SamplingInterval[ms]" in fields:
        fields["SamplingInterval[ms]"] = float(fields["SamplingInterval[ms]"])
    else:
        raise RuntimeError(
            'No "SamplingInterval[ms]" field present. Could '
            "not determine sampling frequency."
        )
    if "BeginSweep[ms]" in fields:
        fields["BeginSweep[ms]"] = float(fields["BeginSweep[ms]"])
    else:
        fields["BeginSweep[ms]"] = 0.0
    if "Bins/uV" in fields:
        fields["Bins/uV"] = float(fields["Bins/uV"])
    else:
        fields["Bins/uV"] = 1
    if "SegmentName" not in fields:
        fields["SegmentName"] = ""

    # Build the Evoked object based on the header fields: the sampling
    # interval in ms becomes a frequency in Hz, the sweep start in ms becomes
    # tmin in seconds, and the data is divided by Bins/uV and then by 1e6.
    info = create_info(
        ch_names, sfreq=1000 / fields["SamplingInterval[ms]"], ch_types=ch_types
    )
    return EvokedArray(
        data.T / fields["Bins/uV"] / 1e6,
        info,
        tmin=fields["BeginSweep[ms]"] / 1000,
        comment=fields["SegmentName"],
        verbose=verbose,
    )


def _parse_header(header):
    """Turn the header line of an .avr or .mul file into a dictionary.

    A typical header line reads:
        Npts= 256   TSB= 0.000 DI= 4.000000 SB= 1.000 SC= 200.0 Nchan= 27
    Separator usage is not consistent across files, so the header is simply
    cut on whitespace and the resulting tokens are read as alternating
    name, value, name, value, ...

    Parameters
    ----------
    header : str
        The first line of the file.

    Returns
    -------
    fields : dict
        Maps each field name (with any ``=`` characters removed) to its value,
        which is left as a string. A trailing name without a value is dropped.
    """
    tokens = header.split()  # any run of whitespace acts as one separator
    fields = {}
    # Even positions hold the names, odd positions the matching values.
    for key, value in zip(tokens[0::2], tokens[1::2]):
        fields[key.replace("=", "")] = value
    return fields


def _read_elp_sidecar(fname):
    """Read a possible .elp sidecar file with electrode information.

    The reason we don't use the read_custom_montage for this is that we are
    interested in the channel types, which a DigMontage object does not provide
    us.

    The first non-blank line decides how the file is read: when it has more
    than three whitespace-separated columns, the first column of every line is
    taken as the channel type and the second as the channel name. Otherwise
    the first column is the channel name and every channel is typed ``"eeg"``.
    Blank lines are ignored.

    Parameters
    ----------
    fname : Path
        The path of the .avr or .mul file. The corresponding .elp file will be
        derived from this path.

    Returns
    -------
    ch_type : OrderedDict | None
        If the sidecar file exists, return a dictionary mapping channel names
        to (lower-cased) channel types, in file order. The dictionary is empty
        when the file holds no electrodes. Otherwise returns ``None``.
    """
    fname_elp = fname.parent / (fname.stem + ".elp")
    if not fname_elp.exists():
        logger.info(f"No {fname_elp} file present containing electrode information.")
        return None

    logger.info(f"Reading electrode names and types from {fname_elp}")
    with open(fname_elp) as f:
        rows = [line.split() for line in f]
    # Blank lines (e.g. trailing newlines) describe no electrode. Dropping them
    # also makes an empty file yield an empty mapping instead of an IndexError.
    rows = [row for row in rows if row]

    ch_types = OrderedDict()
    if rows and len(rows[0]) > 3:
        # Channel types present
        for row in rows:
            ch_type, ch_name = row[:2]
            ch_types[ch_name] = ch_type.lower()
    else:
        # No channel types present
        logger.info(
            "No channel types present in .elp file. Marking all channels as EEG."
        )
        for row in rows:
            # The key must be the name itself: a one-element slice is a list,
            # which is unhashable and cannot be used as a dictionary key.
            ch_types[row[0]] = "eeg"
    return ch_types
